#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Blocks for Polar decoding such as successive cancellation (SC), successive
cancellation list (SCL) and iterative belief propagation (BP) decoding."""

import tensorflow as tf
import numpy as np
import warnings
from sionna.phy import Block
from sionna.phy.fec.crc import CRCDecoder, CRCEncoder
from sionna.phy.fec.polar.encoding import Polar5GEncoder
import numbers

class PolarSCDecoder(Block):
    """Successive cancellation (SC) decoder [Arikan_Polar]_ for Polar codes and
    Polar-like codes.

    Parameters
    ----------
    frozen_pos: ndarray
        Array of `int` defining the ``n-k`` indices of the frozen positions.

    n: int
        Defining the codeword length.

    precision : `None` (default) | 'single' | 'double'
        Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.

    Input
    -----
    llr_ch: [...,n], tf.float
        Tensor containing the channel LLR values (as logits).

    Output
    ------
    : [...,k], tf.float
        Tensor containing hard-decided estimations of all ``k``
        information bits.

    Note
    ----
    This block implements the SC decoder as described in
    [Arikan_Polar]_. However, the implementation follows the `recursive
    tree` [Gross_Fast_SCL]_ terminology and combines nodes for increased
    throughputs without changing the outcome of the algorithm.

    As commonly done, we assume frozen bits are set to `0`. Please note
    that - although its practical relevance is only little - setting frozen
    bits to `1` may result in `affine` codes instead of linear codes as the
    `all-zero` codeword is not necessarily part of the code any more.
    """

    def __init__(self, frozen_pos, n, precision=None, **kwargs):

        super().__init__(precision=precision, **kwargs)

        # validate inputs
        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        n = int(n) # n can be float (e.g. as result of n=k*r)

        if not np.issubdtype(frozen_pos.dtype, int):
            raise TypeError("frozen_pos contains non int.")
        if len(frozen_pos)>n:
            msg = "Num. of elements in frozen_pos cannot be greater than n."
            raise ValueError(msg)
        if np.log2(n)!=int(np.log2(n)):
            raise ValueError("n must be a power of 2.")

        # store internal attributes
        self._n = n
        self._frozen_pos = frozen_pos
        self._k = self._n - len(self._frozen_pos)
        # info positions are the complement of the frozen positions
        self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
        if self._k!=len(self._info_pos):
            msg = "Internal error: invalid info_pos generated."
            raise ArithmeticError(msg)

        self._llr_max = 30. # internal max LLR value (uncritical for SC dec)
        # and create a frozen bit vector for simpler encoding
        self._frozen_ind = np.zeros(self._n)
        self._frozen_ind[self._frozen_pos] = 1

        # enable graph pruning
        self._use_fast_sc = False

    ###############################
    # Public methods and properties
    ###############################

    @property
    def n(self):
        """Codeword length"""
        return self._n

    @property
    def k(self):
        """Number of information bits"""
        return self._k

    @property
    def frozen_pos(self):
        """Frozen positions for Polar decoding"""
        return self._frozen_pos

    @property
    def info_pos(self):
        """Information bit positions for Polar encoding"""
        return self._info_pos

    @property
    def llr_max(self):
        """Maximum LLR value for internal calculations"""
        return self._llr_max

    #################
    # Utility methods
    #################

    def _cn_op_tf(self, x, y):
        """Check-node update (boxplus) for LLR inputs.

        Operations are performed element-wise.

        See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
        """
        x_in = tf.clip_by_value(x,
                                clip_value_min=-self._llr_max,
                                clip_value_max=self._llr_max)
        y_in = tf.clip_by_value(y,
                                clip_value_min=-self._llr_max,
                                clip_value_max=self._llr_max)

        # Numerically stable evaluation of
        # log((1 + e^(x+y)) / (e^x + e^y)),
        # consistent with PolarSCLDecoder._cn_op.
        # Implements log(1+e^(x+y))
        llr_out = tf.math.softplus(x_in + y_in)
        # Implements log(e^x+e^y)
        llr_out -= tf.math.reduce_logsumexp(tf.stack([x_in, y_in], axis=-1),
                                            axis=-1)

        return llr_out

    def _vn_op_tf(self, x, y, u_hat):
        """VN update for LLR inputs."""
        return tf.multiply((1-2*u_hat), x) + y

    def _polar_decode_sc_tf(self, llr_ch, frozen_ind):
        """Recursive SC decoding function.

        Recursively branch decoding tree and split into decoding of `upper`
        and `lower` path until reaching a leaf node.

        The function returns the u_hat decisions at stage `0` and the bit
        decisions of the intermediate stage `s` (i.e., the re-encoded version of
        `u_hat` until the current stage `s`).

        Note:
            This decoder parallelizes over the batch-dimension, i.e., the tree
            is processed for all samples in the batch in parallel. This yields a
            higher throughput, but does not improve the latency.
        """

        # calculate current codeword length
        n = len(frozen_ind)

        # branch if leaf is not reached yet
        if n>1:
            if self._use_fast_sc:
                # rate-0 node (all positions frozen): decision is all-zero
                if np.sum(frozen_ind)==n:
                    u_hat = tf.zeros_like(llr_ch)
                    return u_hat, u_hat

            # split LLRs and frozen indicators into upper/lower half
            llr_ch1 = llr_ch[...,0:int(n/2)]
            llr_ch2 = llr_ch[...,int(n/2):]
            frozen_ind1 = frozen_ind[0:int(n/2)]
            frozen_ind2 = frozen_ind[int(n/2):]

            # upper path
            x_llr1_in = self._cn_op_tf(llr_ch1, llr_ch2)

            # and call the decoding function (with upper half)
            u_hat1, u_hat1_up = self._polar_decode_sc_tf(x_llr1_in, frozen_ind1)

            # lower path
            x_llr2_in = self._vn_op_tf(llr_ch1, llr_ch2, u_hat1_up)
            # and call the decoding function again (with lower half)
            u_hat2, u_hat2_up = self._polar_decode_sc_tf(x_llr2_in, frozen_ind2)

            # combine u_hat from both branches
            u_hat = tf.concat([u_hat1, u_hat2], -1)

            # calculate re-encoded version of u_hat at current stage
            # combine u_hat via bitwise_xor (more efficient than mod2)
            u_hat1_up_int = tf.cast(u_hat1_up, tf.int8)
            u_hat2_up_int = tf.cast(u_hat2_up, tf.int8)
            u_hat1_up_int = tf.bitwise.bitwise_xor(u_hat1_up_int,
                                                   u_hat2_up_int)
            u_hat1_up = tf.cast(u_hat1_up_int, self.rdtype)
            u_hat_up = tf.concat([u_hat1_up, u_hat2_up], -1)

        else: # if leaf is reached perform basic decoding op (=decision)

            if frozen_ind==1: # position is frozen
                u_hat = tf.expand_dims(tf.zeros_like(llr_ch[:,0]), axis=-1)
                u_hat_up = u_hat
            else: # otherwise hard decide
                u_hat = 0.5 * (1. - tf.sign(llr_ch))
                # remove "exact 0 llrs" leading to u_hat=0.5
                u_hat = tf.where(tf.equal(u_hat, 0.5),
                                 tf.ones_like(u_hat),
                                 u_hat)
                u_hat_up = u_hat
        return u_hat, u_hat_up

    ########################
    # Sionna Block functions
    ########################

    def build(self, input_shape):
        """Check if shape of input is invalid."""

        if input_shape[-1]!=self._n:
            raise ValueError("Invalid input shape.")

    def call(self, llr_ch, /):
        """Successive cancellation (SC) decoding function.

        Performs successive cancellation decoding and returns the estimated
        information bits.

        Args:
            llr_ch (tf.float): Tensor of shape `[...,n]` containing the
                channel LLR values (as logits).

        Returns:
            `tf.float`: Tensor of shape `[...,k]` containing
            hard-decided estimations of all ``k`` information bits.

        Note:
            This function recursively unrolls the SC decoding tree, thus,
            for larger values of ``n`` building the decoding graph can become
            time consuming.
        """

        # Reshape inputs to [-1, n]
        input_shape = llr_ch.shape
        new_shape = [-1, self._n]
        llr_ch = tf.reshape(llr_ch, new_shape)

        llr_ch = -1. * llr_ch # logits are converted into "true" llrs

        # and decode
        u_hat_n, _ = self._polar_decode_sc_tf(llr_ch, self._frozen_ind)

        # and recover the k information bit positions
        u_hat = tf.gather(u_hat_n, self._info_pos, axis=1)

        # and reconstruct input shape
        output_shape = input_shape.as_list()
        output_shape[-1] = self.k
        output_shape[0] = -1 # first dim can be dynamic (None)
        u_hat_reshape = tf.reshape(u_hat, output_shape)
        return u_hat_reshape

class PolarSCLDecoder(Block):
    # pylint: disable=line-too-long
    """Successive cancellation list (SCL) decoder [Tal_SCL]_ for Polar codes
    and Polar-like codes.

    Parameters
    ----------
    frozen_pos: ndarray
        Array of `int` defining the ``n-k`` indices of the frozen positions.

    n: int
        Defining the codeword length.

    list_size: int, (default 8)
        Defines the list size of the decoder.

    crc_degree: str, "CRC24A" | "CRC24B" | "CRC24C" | "CRC16" | "CRC11" | "CRC6"
        Defining the CRC polynomial to be used. Can be any value from
        `{CRC24A, CRC24B, CRC24C, CRC16, CRC11, CRC6}`.

    use_hybrid_sc: `bool`, (default `False`)
        If `True`,  SC decoding is applied and only the codewords with invalid
        CRC are decoded with SCL. This option requires an outer CRC specified
        via ``crc_degree``. Remark: hybrid_sc does not support XLA optimization,
        i.e., `@tf.function(jit_compile=True)`.

    use_fast_scl: `bool`, (default `True`)
        If `True`,  Tree pruning is used to
        reduce the decoding complexity. The output is equivalent to the
        non-pruned version (besides numerical differences).

    cpu_only: `bool`, (default `False`)
        If `True`,  `tf.py_function` embedding is used and the decoder runs on
        the CPU. This option is usually slower, but also more memory efficient
        and in particular, recommended for larger blocklengths. Remark: cpu_only
        does not support XLA optimization `@tf.function(jit_compile=True)`.

    use_scatter: `bool`, (default `False`)
        If `True`,  `tf.tensor_scatter_update` is used for tensor updates. This
        option is usually slower, but more memory efficient.

    ind_iil_inv : None or [k+k_crc], int or tf.int
        Defaults to None. If not `None`, the sequence is used as inverse
        input bit interleaver before evaluating the CRC.
        Remark: this only affects the CRC evaluation; the output
        sequence is not permuted.

    return_crc_status: `bool`, (default `False`)
        If `True`,  the decoder additionally returns the CRC status indicating
        if a codeword was (most likely) correctly recovered. This is only
        available if ``crc_degree`` is not None.

    precision : `None` (default) | 'single' | 'double'
        Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.

    Input
    -----
    llr_ch: [...,n], tf.float
        Tensor containing the channel LLR values (as logits).

    Output
    ------
    b_hat : [...,k], tf.float
        Binary tensor containing hard-decided estimations of all `k`
        information bits.

    crc_status : [...], tf.bool
        CRC status indicating if a codeword was (most likely) correctly
        recovered. This is only returned if ``return_crc_status`` is True.
        Note that false positives are possible.

    Note
    ----
    This block implements the successive cancellation list (SCL) decoder
    as described in [Tal_SCL]_ but uses LLR-based message updates
    [Stimming_LLR]_. The implementation follows the notation from
    [Gross_Fast_SCL]_, [Hashemi_SSCL]_. If option `use_fast_scl` is active
    tree pruning is used and tree nodes are combined if possible (see
    [Hashemi_SSCL]_ for details).

    Implementing SCL decoding as TensorFlow graph is a difficult task that
    requires several design tradeoffs to match the TF constraints while
    maintaining a reasonable throughput. Thus, the decoder minimizes
    the `control flow` as much as possible, leading to a strong memory
    occupation (e.g., due to full path duplication after each decision).
    For longer code lengths, the complexity of the decoding graph becomes
    large and we recommend to use the `CPU_only` option that uses an
    embedded Numpy decoder. Further, this function recursively unrolls the
    SCL decoding tree, thus, for larger values of ``n`` building the
    decoding graph can become time consuming. Please consider the
    ``cpu_only`` option if building the graph takes too long.

    A hybrid SC/SCL decoder as proposed in [Cammerer_Hybrid_SCL]_ (using SC
    instead of BP) can be activated with option ``use_hybrid_sc`` iff an
    outer CRC is available. Please note that the results are not exactly
    SCL performance caused by the false positive rate of the CRC.

    As commonly done, we assume frozen bits are set to `0`. Please note
    that - although its practical relevance is only little - setting frozen
    bits to `1` may result in `affine` codes instead of linear code as the
    `all-zero` codeword is not necessarily part of the code any more.
    """

    def __init__(self,
                 frozen_pos,
                 n,
                 list_size=8,
                 crc_degree=None,
                 use_hybrid_sc=False,
                 use_fast_scl=True,
                 cpu_only=False,
                 use_scatter=False,
                 ind_iil_inv=None,
                 return_crc_status=False,
                 precision=None,
                 **kwargs):

        super().__init__(precision=precision, **kwargs)

        # validate inputs
        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        n = int(n) # n can be float (e.g. as result of n=k*r)
        if not isinstance(list_size, int):
            raise TypeError("list_size must be integer.")
        if not isinstance(cpu_only, bool):
            raise TypeError("cpu_only must be bool.")
        if not isinstance(use_scatter, bool):
            raise TypeError("use_scatter must be bool.")
        if not isinstance(use_fast_scl, bool):
            raise TypeError("use_fast_scl must be bool.")
        if not isinstance(use_hybrid_sc, bool):
            raise TypeError("use_hybrid_sc must be bool.")
        if not isinstance(return_crc_status, bool):
            raise TypeError("return_crc_status must be bool.")

        if not np.issubdtype(frozen_pos.dtype, int):
            raise TypeError("frozen_pos contains non int.")
        if len(frozen_pos)>n:
            msg = "Num. of elements in frozen_pos cannot be greater than n."
            raise ValueError(msg)
        if np.log2(n)!=int(np.log2(n)):
            raise ValueError("n must be a power of 2.")
        if np.log2(list_size)!=int(np.log2(list_size)):
            raise ValueError("list_size must be a power of 2.")

        # CPU mode is recommended for larger values of n
        if n>128 and cpu_only is False and use_hybrid_sc is False:
            warnings.warn("Required resource allocation is large " \
            "for the selected blocklength. Consider option `cpu_only=True`.")

        # CPU mode is recommended for larger values of L
        if list_size>32 and cpu_only is False and use_hybrid_sc is False:
            warnings.warn("Resource allocation is high for the " \
            "selected list_size. Consider option `cpu_only=True`.")

        # internal decoder parameters
        self._use_fast_scl = use_fast_scl # optimize rate-0 and rep nodes
        self._use_scatter = use_scatter # slower but more memory friendly
        self._cpu_only = cpu_only # run numpy decoder
        self._use_hybrid_sc = use_hybrid_sc

        # store internal attributes
        self._n = n
        self._frozen_pos = frozen_pos
        self._k = self._n - len(self._frozen_pos)
        self._list_size = list_size
        self._info_pos = np.setdiff1d(np.arange(self._n), self._frozen_pos)
        self._llr_max = 30. # internal max LLR value (not very critical for SC)
        if self._k!=len(self._info_pos):
            raise ArithmeticError("Internal error: invalid info_pos generated.")

        # create a frozen bit vector
        self._frozen_ind = np.zeros(self._n)
        self._frozen_ind[self._frozen_pos] = 1
        self._cw_ind = np.arange(self._n)
        self._n_stages = int(np.log2(self._n)) # number of decoding stages

        # init CRC check (if needed)
        if crc_degree is not None:
            self._use_crc = True
            self._crc_encoder = CRCEncoder(crc_degree, precision=precision)
            self._crc_decoder = CRCDecoder(self._crc_encoder,
                                           precision=precision)
            self._k_crc = self._crc_decoder.encoder.crc_length
        else:
            self._use_crc = False
            self._k_crc = 0
        if self._k<self._k_crc:
            msg = "Value of k is too small for given CRC_degree."
            raise ValueError(msg)

        # CRC status can only be returned if an outer CRC is available
        if (crc_degree is None) and return_crc_status:
            raise ValueError("Returning CRC status requires given crc_degree.")
        self._return_crc_status = return_crc_status

        # store the inverse interleaver pattern
        if ind_iil_inv is not None:
            # note: self._k includes the CRC bits here (k = n - |frozen_pos|)
            if ind_iil_inv.shape[0]!=self._k:
                raise ValueError("ind_iil_inv must be of length k+k_crc.")
            self._ind_iil_inv = ind_iil_inv
            self._iil = True
        else:
            self._iil = False

        # use SC decoder first and use numpy-based SCL as "afterburner"
        if self._use_hybrid_sc:
            self._decoder_sc = PolarSCDecoder(frozen_pos, n,
                                              precision=precision)
            # Note: CRC required to detect SC success
            if not self._use_crc:
                raise ValueError("Hybrid SC requires outer CRC.")

    ###############################
    # Public methods and properties
    ###############################

    @property
    def n(self):
        """Codeword length ``n``"""
        return self._n

    @property
    def k(self):
        """Number of information bits ``k`` (incl. CRC bits, if any)"""
        return self._k

    @property
    def k_crc(self):
        """Number of CRC bits"""
        return self._k_crc

    @property
    def frozen_pos(self):
        """Frozen positions for Polar decoding"""
        return self._frozen_pos

    @property
    def info_pos(self):
        """Information bit positions for Polar encoding"""
        return self._info_pos

    @property
    def llr_max(self):
        """Maximum LLR value used for internal clipping"""
        return self._llr_max

    @property
    def list_size(self):
        """List size for SCL decoding"""
        return self._list_size

    #####################################
    # Helper functions for the TF decoder
    #####################################

    def _update_rate0_code(self, msg_pm, msg_uhat, msg_llr, cw_ind):
        """Update rate-0 sub-code (i.e., all frozen) at pos ``cw_ind``.

        See eq. (26) in [Hashemi_SSCL]_.

        Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
        initialized with `0` already.
        """
        num_bits = len(cw_ind)
        stage = int(np.log2(num_bits))

        # fetch and saturate the LLRs of the current stage
        llr_stage = tf.gather(msg_llr[:, :, stage, :], cw_ind, axis=2)
        llr_clip = tf.clip_by_value(llr_stage,
                                    clip_value_min=-self._llr_max,
                                    clip_value_max=self._llr_max)

        # accumulate the penalty log(1+exp(-llr)) over the whole sub-block
        penalty = tf.math.softplus(-1.*llr_clip)
        msg_pm = msg_pm + tf.reduce_sum(penalty, axis=-1)

        return msg_pm, msg_uhat, msg_llr

    def _update_rep_code(self, msg_pm, msg_uhat, msg_llr, cw_ind):
        """Update rep. code (i.e., only rightmost bit is non-frozen)
        sub-code at position ``ind_u``.

        See Eq. (31) in [Hashemi_SSCL]_.

        Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
        initialized with `0` already.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # update PM
        llr = tf.gather(msg_llr[:, :, stage_ind, :], cw_ind, axis=2)
        llr_in = tf.clip_by_value(llr,
                                  clip_value_min=-self._llr_max,
                                  clip_value_max=self._llr_max)

        # upper branch has negative llr values (bit is 1)
        llr_low =  llr_in[:, :self._list_size, :]
        llr_up = - llr_in[:, self._list_size:, :]
        llr_pm = tf.concat([llr_low, llr_up], 1)
        pm_val = tf.math.softplus(-1.*llr_pm)
        msg_pm += tf.reduce_sum(pm_val, axis=-1)

        # lower half of the list hypothesizes the all-ones pattern at this
        # stage; splice ones into the stage_ind slice at positions cw_ind
        msg_uhat1 = msg_uhat[:, :self._list_size, :, :]
        msg_uhat21 = tf.expand_dims(
                        msg_uhat[:, self._list_size:, stage_ind, :cw_ind[0]],
                        axis=2)

        msg_uhat22= tf.expand_dims(
                        msg_uhat[:, self._list_size:, stage_ind, cw_ind[-1]+1:],
                        axis=2)
        # ones to insert
        msg_ones = tf.ones([tf.shape(msg_uhat)[0], self._list_size, 1, n],
                            self.rdtype)

        # re-assemble the stage slice and then the full tensor
        msg_uhat23 = tf.concat([msg_uhat21, msg_ones, msg_uhat22], 3)
        msg_uhat24_1 = msg_uhat[:, self._list_size:, :stage_ind, :]
        msg_uhat24_2 = msg_uhat[:, self._list_size:, stage_ind+1:, :]

        msg_uhat2 = tf.concat([msg_uhat24_1, msg_uhat23, msg_uhat24_2], 2)
        msg_uhat = tf.concat([msg_uhat1, msg_uhat2], 1)

        # branch last bit and update pm at pos cw_ind[-1]
        msg_uhat = self._update_single_bit([cw_ind[-1]], msg_uhat)
        msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
                                                        msg_uhat,
                                                        msg_llr)
        msg_uhat, msg_llr, msg_pm = self._duplicate_paths(msg_uhat,
                                                          msg_llr,
                                                          msg_pm)
        return msg_pm, msg_uhat, msg_llr

    def _update_single_bit(self, ind_u, msg_uhat):
        """Update single bit at position ``ind_u`` for all decoders.

        The lower half of the list (decoders ``list_size`` to
        ``2*list_size-1``) hypothesizes a `1` at position ``ind_u``; the upper
        half keeps the `0` hypothesis.

        Remark: bits are not explicitly set to `0` as ``msg_uhat`` is
        initialized with `0` already.

        Remark: Two versions are implemented (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """
        # position is non-frozen
        if self._frozen_ind[ind_u[0]]==0:

            # msg_uhat[:, ind_up, 0, ind_u] = 1
            if self._use_scatter:
                # indices of the lower-half decoders at stage 0
                ind_dec = np.arange(self._list_size, 2*self._list_size, 1)
                ind_stage = np.array([0])

                # transpose such that batch dim can be broadcasted
                msg_uhat_t = tf.transpose(msg_uhat, [1, 3, 2, 0])

                # generate index grid
                ind_u = tf.cast(ind_u, tf.int64)
                grid = tf.meshgrid(ind_dec, ind_u, ind_stage)
                ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 3])

                updates = tf.ones([ind.shape[0], tf.shape(msg_uhat)[0]],
                                  self.rdtype)
                msg_uhat_s = tf.tensor_scatter_nd_update(msg_uhat_t,
                                                         ind,
                                                         updates)
                # and restore original order
                msg_uhat = tf.transpose(msg_uhat_s, [3, 0, 2, 1])
            else:
                # alternative solution with split/concatenation of graph
                msg_uhat1 = msg_uhat[:, :self._list_size, :, :]
                msg_uhat21 = tf.expand_dims(
                                msg_uhat[:, self._list_size:, 0, :ind_u[0]],
                                axis=2)

                msg_uhat22= tf.expand_dims(
                                msg_uhat[:, self._list_size:, 0, ind_u[0]+1:],
                                axis=2)
                # ones to insert
                msg_ones = tf.ones_like(tf.reshape(
                                msg_uhat[:, self._list_size:, 0, ind_u[0]],
                                [-1, self._list_size, 1, 1]))

                # splice the `1` into stage 0 and re-assemble the tensor
                msg_uhat23 = tf.concat([msg_uhat21, msg_ones, msg_uhat22], 3)
                msg_uhat24 = msg_uhat[:, self._list_size:, 1:, :]

                msg_uhat2 = tf.concat([msg_uhat23, msg_uhat24], 2)
                msg_uhat = tf.concat([msg_uhat1, msg_uhat2], 1)

        return msg_uhat

    def _update_pm(self, ind_u, msg_uhat, msg_llr, msg_pm):
        """Update path metric of all decoders after updating bit_pos ``ind_u``.

        We implement (10) from [Stimming_LLR]_.
        """
        pos = ind_u[0]
        bit = msg_uhat[:, :, 0, pos]
        llr_val = msg_llr[:, :, 0, pos]

        llr_clip = tf.clip_by_value(llr_val,
                                    clip_value_min=-self._llr_max,
                                    clip_value_max=self._llr_max)

        # numerically stable form of log(1 + exp(-(1-2u)*llr)) via softplus
        sign = 1. - 2.*bit
        msg_pm = msg_pm + tf.math.softplus(-sign * llr_clip)
        return msg_pm

    def _sort_decoders(self, msg_pm, msg_uhat, msg_llr):
        """Sort decoders in ascending order of their path metric."""

        # per-sample ordering of the list dimension
        order = tf.argsort(msg_pm, axis=-1)

        # reorder all messages accordingly
        pm_sorted = tf.gather(msg_pm, order, batch_dims=1, axis=None)
        uhat_sorted = tf.gather(msg_uhat, order, batch_dims=1, axis=None)
        llr_sorted = tf.gather(msg_llr, order, batch_dims=1, axis=None)

        return pm_sorted, uhat_sorted, llr_sorted

    def _cn_op(self, x, y):
        """Check-node update (boxplus) for LLR inputs.

        Operations are performed element-wise.

        See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
        """
        a = tf.clip_by_value(x,
                             clip_value_min=-self._llr_max,
                             clip_value_max=self._llr_max)
        b = tf.clip_by_value(y,
                             clip_value_min=-self._llr_max,
                             clip_value_max=self._llr_max)

        # evaluate log((1+e^(a+b)) / (e^a+e^b)) without explicit division
        # numerator: log(1+e^(a+b))
        num = tf.math.softplus(a + b)
        # denominator: log(e^a+e^b)
        den = tf.math.reduce_logsumexp(tf.stack([a, b], axis=-1), axis=-1)

        return num - den

    def _vn_op(self, x, y, u_hat):
        """Variable node update for LLR inputs.

        Operations are performed element-wise.

        See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
        """
        # flip the sign of x wherever u_hat equals 1, then add y
        sign = 1 - 2*u_hat
        return sign * x + y

    def _duplicate_paths(self, msg_uhat, msg_llr, msg_pm):
        """Duplicate paths by copying the upper branch into the lower one.
        """
        ls = self._list_size

        # keep the best `list_size` decoders and clone them once
        uhat_best = msg_uhat[:, :ls, :, :]
        llr_best = msg_llr[:, :ls, :, :]
        pm_best = msg_pm[:, :ls]

        return (tf.tile(uhat_best, [1, 2, 1, 1]),
                tf.tile(llr_best, [1, 2, 1, 1]),
                tf.tile(pm_best, [1, 2]))

    def _update_left_branch(self, msg_llr, stage_ind, cw_ind_left,cw_ind_right):
        """Update messages of left branch.

        Applies the check-node (f-function) update to the LLRs at stage
        ``stage_ind`` and writes the result to stage ``stage_ind-1`` at
        positions ``cw_ind_left``.

        Remark: Two versions are implemented (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """

        llr_left_in = tf.gather(msg_llr[:, :, stage_ind, :],
                                cw_ind_left,
                                axis=2)
        llr_right_in = tf.gather(msg_llr[:, :, stage_ind, :],
                                 cw_ind_right,
                                 axis=2)

        # boxplus of left and right LLRs
        llr_left_out = self._cn_op(llr_left_in, llr_right_in)

        if self._use_scatter:
            # self.msg_llr[:, :, stage_ind-1, cw_ind_left] = llr_left_out

            # transpose such that batch-dim can be broadcasted
            msg_llr_t = tf.transpose(msg_llr, [2, 3, 1, 0])
            llr_left_out_s = tf.transpose(llr_left_out, [2, 1, 0])

            # generate index grid
            stage_ind = tf.cast(stage_ind, tf.int64)
            cw_ind_left = tf.cast(cw_ind_left, tf.int64)
            grid = tf.meshgrid(stage_ind-1, cw_ind_left)
            ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])

            # update values
            msg_llr_s = tf.tensor_scatter_nd_update(msg_llr_t,
                                                    ind,
                                                    llr_left_out_s)

            # and restore original order
            msg_llr = tf.transpose(msg_llr_s, [3, 2, 0, 1])
        else:
            # alternative solution with split/concatenation of graph
            # llr_left = msg_llr[:, :, stage_ind, cw_ind_left]
            llr_left0 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                  np.arange(0, cw_ind_left[0]),
                                  axis=2)

            llr_right = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                  cw_ind_right,
                                  axis=2)
            llr_right1 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                   np.arange(cw_ind_right[-1] +1, self._n),
                                   axis=2)

            # stitch the updated stage_ind-1 slice back together
            llr_s = tf.concat([llr_left0,
                               llr_left_out,
                               llr_right,
                               llr_right1], 2)

            llr_s = tf.expand_dims(llr_s, axis=2)

            msg_llr1 = msg_llr[:, :, 0:stage_ind-1, :]
            msg_llr2 = msg_llr[:, :, stage_ind:, :]
            msg_llr = tf.concat([msg_llr1, llr_s, msg_llr2], 2)

        return msg_llr

    def _update_right_branch(self, msg_llr, msg_uhat, stage_ind, cw_ind_left,
                             cw_ind_right):
        """Update messages for right branch.

        Applies the variable-node (g-function) update using the partial sums
        ``u_hat`` of the already decoded left branch and writes the result to
        stage ``stage_ind-1`` at positions ``cw_ind_right``.

        Remark: Two versions are implemented (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """
        # partial sums of the left branch at the lower stage
        u_hat_left_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
                                  cw_ind_left,
                                  axis=2)

        llr_left_in = tf.gather(msg_llr[:, :, stage_ind, :],
                                cw_ind_left,
                                axis=2)

        llr_right = tf.gather(msg_llr[:, :, stage_ind, :],
                              cw_ind_right,
                              axis=2)

        llr_right_out = self._vn_op(llr_left_in, llr_right, u_hat_left_up)

        if self._use_scatter:
            # transpose such that batch dim can be broadcasted
            msg_llr_t = tf.transpose(msg_llr, [2, 3, 1, 0])
            llr_right_out_s = tf.transpose(llr_right_out, [2, 1, 0])

            # generate index grid
            stage_ind = tf.cast(stage_ind, tf.int64)
            # fix: cast cw_ind_right (the result was previously assigned to
            # cw_ind_left by mistake, leaving cw_ind_right uncast for the
            # meshgrid below; cf. the correct pattern in _update_left_branch)
            cw_ind_right = tf.cast(cw_ind_right, tf.int64)
            grid = tf.meshgrid(stage_ind-1, cw_ind_right)
            ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])

            msg_llr_s = tf.tensor_scatter_nd_update(msg_llr_t,
                                                    ind,
                                                    llr_right_out_s)

            # and restore original order
            msg_llr = tf.transpose(msg_llr_s, [3, 2, 0, 1])
        else:
            # alternative solution with split/concatenation of graph
            llr_left0 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                  np.arange(0, cw_ind_left[0]),
                                  axis=2)
            llr_left = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                 cw_ind_left,
                                 axis=2)
            llr_right1 = tf.gather(msg_llr[:, :, stage_ind-1, :],
                                   np.arange(cw_ind_right[-1]+1, self._n),
                                   axis=2)

            # stitch the updated stage_ind-1 slice back together
            llr_s = tf.concat([llr_left0, llr_left, llr_right_out,
                               llr_right1], 2)
            llr_s = tf.expand_dims(llr_s, axis=2)

            msg_llr1 = msg_llr[:, :, 0:stage_ind-1, :]
            msg_llr2 = msg_llr[:, :, stage_ind:, :]

            msg_llr = tf.concat([msg_llr1, llr_s, msg_llr2], 2)

        return msg_llr

    def _update_branch_u(self, msg_uhat, stage_ind, cw_ind_left, cw_ind_right):
        """Update ``u_hat`` messages after executing both branches.

        Combines the decisions of the left and right sub-trees from stage
        ``stage_ind-1`` and writes the result to stage ``stage_ind``
        (left half: XOR of both halves; right half: copied unchanged).

        Remark: Two versions are implemented (throughput vs. graph complexity):
        1.) use tensor_scatter_nd_update
        2.) explicitly split graph and concatenate again
        """
        # decisions of the left/right sub-trees (one stage below)
        u_hat_left_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
                                  cw_ind_left,
                                  axis=2)

        u_hat_right_up = tf.gather(msg_uhat[:, :, stage_ind-1, :],
                                   cw_ind_right,
                                   axis=2)

        # combine u_hat via bitwise_xor (more efficient than mod2)
        u_hat_left_up_int = tf.cast(u_hat_left_up, tf.int32)
        u_hat_right_up_int = tf.cast(u_hat_right_up, tf.int32)
        u_hat_left = tf.bitwise.bitwise_xor(u_hat_left_up_int,
                                            u_hat_right_up_int)
        u_hat_left = tf.cast(u_hat_left, self.rdtype)

        if self._use_scatter:
            # all positions of the current sub-code are updated at once
            cw_ind = np.concatenate([cw_ind_left, cw_ind_right])

            u_hat = tf.concat([u_hat_left, u_hat_right_up], -1)

            # transpose such that batch dim can be broadcasted
            msg_uhat_t = tf.transpose(msg_uhat, [2, 3, 1, 0])
            u_hat_s = tf.transpose(u_hat, [2, 1, 0])

            # generate index grid (pairs of [stage, position])
            stage_ind = tf.cast(stage_ind, tf.int64)
            cw_ind = tf.cast(cw_ind, tf.int64)
            grid = tf.meshgrid(stage_ind, cw_ind)
            ind = tf.reshape(tf.stack(grid, axis=-1), [-1, 2])

            # scatter the combined decisions into stage stage_ind
            msg_uhat_s = tf.tensor_scatter_nd_update(msg_uhat_t,
                                                     ind,
                                                     u_hat_s)

            # and restore original order
            msg_uhat = tf.transpose(msg_uhat_s, [3, 2, 0, 1])
        else:
            # alternative solution with split/concatenation of graph
            # untouched positions left/right of the current sub-code
            u_hat_left_0 = tf.gather(msg_uhat[:, :, stage_ind, :],
                                     np.arange(0, cw_ind_left[0]),
                                     axis=2)
            u_hat_right_1 = tf.gather(msg_uhat[:, :, stage_ind, :],
                                      np.arange(cw_ind_right[-1]+1, self._n),
                                      axis=2)

            u_hat = tf.concat([u_hat_left_0,
                               u_hat_left,
                               u_hat_right_up,
                               u_hat_right_1], 2)

            # provide u_hat for next higher stage
            msg_uhat1 = msg_uhat[:, :, 0:stage_ind, :]
            msg_uhat2 = msg_uhat[:, :, stage_ind+1:, :]
            u_hat = tf.expand_dims(u_hat, axis=2)

            msg_uhat = tf.concat([msg_uhat1, u_hat, msg_uhat2], 2)

        return msg_uhat

    def _polar_decode_scl(self, cw_ind, msg_uhat, msg_llr, msg_pm):
        """Recursively decode the sub-code at positions ``cw_ind`` (SCL).

        Follows the left/right branching terminology from [Hashemi_SSCL]_
        and [Stimming_LLR]_. Rate-0 and repetition sub-codes are pruned as
        proposed in [Hashemi_SSCL]_, which reduces the tree depth without
        changing the decoder output.
        """
        # sub-code length and corresponding stage (= tree depth)
        num_pos = len(cw_ind)
        stage_ind = int(np.log2(num_pos))

        # leaf node reached: perform the actual bit decision
        if num_pos == 1:
            msg_uhat = self._update_single_bit(cw_ind, msg_uhat)
            msg_pm = self._update_pm(cw_ind, msg_uhat, msg_llr, msg_pm)

            # non-frozen position: branch the decoding paths
            if self._frozen_ind[cw_ind] == 0:
                # keep the l best decoders ...
                msg_pm, msg_uhat, msg_llr = self._sort_decoders(
                    msg_pm, msg_uhat, msg_llr)
                # ... and duplicate them into positions l:2*l
                msg_uhat, msg_llr, msg_pm = self._duplicate_paths(
                    msg_uhat, msg_llr, msg_pm)
            return msg_uhat, msg_llr, msg_pm

        # tree pruning for special sub-codes
        if self._use_fast_scl:
            num_frozen = np.sum(self._frozen_ind[cw_ind])
            if num_frozen == num_pos:
                # rate-0 (all-frozen) sub-code
                msg_pm, msg_uhat, msg_llr = self._update_rate0_code(
                    msg_pm, msg_uhat, msg_llr, cw_ind)
                return msg_uhat, msg_llr, msg_pm
            if (self._frozen_ind[cw_ind[-1]] == 0
                    and num_frozen == num_pos - 1):
                # repetition code (only rightmost bit is non-frozen)
                msg_pm, msg_uhat, msg_llr = self._update_rep_code(
                    msg_pm, msg_uhat, msg_llr, cw_ind)
                return msg_uhat, msg_llr, msg_pm

        # split positions into left and right half
        cw_ind_left = cw_ind[0:num_pos//2]
        cw_ind_right = cw_ind[num_pos//2:]

        # ----- left branch: CN update, then decode sub-tree -----
        msg_llr = self._update_left_branch(msg_llr,
                                           stage_ind,
                                           cw_ind_left,
                                           cw_ind_right)
        msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(cw_ind_left,
                                                           msg_uhat,
                                                           msg_llr,
                                                           msg_pm)

        # ----- right branch: VN update, then decode sub-tree -----
        msg_llr = self._update_right_branch(msg_llr,
                                            msg_uhat,
                                            stage_ind,
                                            cw_ind_left,
                                            cw_ind_right)
        msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(cw_ind_right,
                                                           msg_uhat,
                                                           msg_llr,
                                                           msg_pm)

        # propagate the combined decisions to the current stage
        msg_uhat = self._update_branch_u(msg_uhat,
                                         stage_ind,
                                         cw_ind_left,
                                         cw_ind_right)

        return msg_uhat, msg_llr, msg_pm

    def _decode_tf(self, llr_ch):
        """Main decoding function in TF.

        Allocates the message memory for all ``2*list_size`` decoders,
        initializes the path metrics and runs the recursive SCL decoder.
        Returns the (sorted) ``[msg_uhat, msg_pm]`` lists.
        """
        batch_size = tf.shape(llr_ch)[0]
        num_dec = 2*self._list_size  # both decoder halves

        # message memory for all decoders
        msg_uhat = tf.zeros([batch_size, num_dec, self._n_stages+1, self._n],
                            self.rdtype)
        msg_llr = tf.zeros([batch_size, num_dec, self._n_stages, self._n],
                           self.rdtype)

        # broadcast channel LLRs to every decoder ...
        llr_ch = tf.tile(tf.reshape(llr_ch, [-1, 1, 1, self._n]),
                         [1, num_dec, 1, 1])
        # ... and append them as last stage
        msg_llr = tf.concat([msg_llr, llr_ch], 2)

        # first decoder of each half starts with zero path metric;
        # the remaining L-1 decoders receive a high penalty
        pm0 = tf.zeros([batch_size, 1], self.rdtype)
        pm1 = self._llr_max * tf.ones([batch_size, self._list_size-1],
                                      self.rdtype)
        msg_pm = tf.concat([pm0, pm1, pm0, pm1], 1)

        # run recursive SCL decoding over the full codeword
        msg_uhat, msg_llr, msg_pm = self._polar_decode_scl(self._cw_ind,
                                                           msg_uhat,
                                                           msg_llr,
                                                           msg_pm)

        # sort decoders by path metric before returning
        msg_pm, msg_uhat, msg_llr = self._sort_decoders(msg_pm,
                                                        msg_uhat,
                                                        msg_llr)
        return [msg_uhat, msg_pm]

    ####################################
    # Helper functions for Numpy decoder
    ####################################

    def _update_rate0_code_np(self, cw_ind):
        """Update rate-0 (i.e., all frozen) sub-code at pos ``cw_ind`` in Numpy.

        See Eq. (26) in [Hashemi_SSCL]_.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # update PM for each batch sample
        ind = np.expand_dims(self._dec_pointer, axis=-1)
        llr_in = np.take_along_axis(self.msg_llr[:, :, stage_ind, cw_ind],
                                    ind,
                                    axis=1)

        llr_clip = np.maximum(np.minimum(llr_in, self._llr_max), -self._llr_max)
        pm_val = np.log(1 + np.exp(-llr_clip))
        self.msg_pm += np.sum(pm_val, axis=-1)

    def _update_rep_code_np(self, cw_ind):
        """Update rep. code (i.e., only rightmost bit is non-frozen)
        sub-code at position ``ind_u`` in Numpy.

        See Eq. (31) in [Hashemi_SSCL]_.
        """
        n = len(cw_ind)
        stage_ind = int(np.log2(n))
        bs = self._dec_pointer.shape[0]

        # update PM
        # gather LLRs of all decoders in pointer order
        llr = np.zeros([bs, 2*self._list_size, n])
        for i in range(bs):
            llr_i = self.msg_llr[i, self._dec_pointer[i, :], stage_ind, :]
            llr[i, :, :] = llr_i[:, cw_ind]

        # upper branch has negative llr values (bit is 1)
        llr[:, self._list_size:, :] = - llr[:, self._list_size:, :]
        llr_in = np.maximum(np.minimum(llr, self._llr_max), -self._llr_max)
        pm_val = np.sum(np.log(1 + np.exp(-llr_in)), axis=-1)
        self.msg_pm += pm_val

        # upper list half stores the all-one codeword hypothesis
        for i in range(bs):
            ind_dec = self._dec_pointer[i, self._list_size:]
            for j in cw_ind:
                self.msg_uhat[i, ind_dec, stage_ind, j] = 1

        # branch last bit and update pm at pos cw_ind[-1]
        self._update_single_bit_np([cw_ind[-1]])
        self._sort_decoders_np()
        self._duplicate_paths_np()

    def _update_single_bit_np(self, ind_u):
        """Update single bit at position ``ind_u`` of all decoders in Numpy."""

        if self._frozen_ind[ind_u]==0: # position is non-frozen
            ind_dec = np.expand_dims(self._dec_pointer[:, self._list_size:],
                                     axis=-1)
            uhat_slice = self.msg_uhat[:, :, 0, ind_u]
            np.put_along_axis(uhat_slice, ind_dec, 1., axis=1)
            self.msg_uhat[:, :, 0, ind_u] = uhat_slice


    def _update_pm_np(self, ind_u):
        """ Update path metric of all decoders at bit position ``ind_u`` in
        Numpy.

        We apply Eq. (10) from [Stimming_LLR]_.
        """
        ind = np.expand_dims(self._dec_pointer, axis=-1)
        u_hat = np.take_along_axis(self.msg_uhat[:, :, 0, ind_u], ind, axis=1)
        u_hat = np.squeeze(u_hat, axis=-1)
        llr_in = np.take_along_axis(self.msg_llr[:, :, 0, ind_u], ind, axis=1)
        llr_in = np.squeeze(llr_in, axis=-1)

        llr_clip = np.maximum(np.minimum(llr_in, self._llr_max), -self._llr_max)
        self.msg_pm += np.log(1 + np.exp(-np.multiply((1-2*u_hat), llr_clip)))

    def _sort_decoders_np(self):
        """Sort decoders according to their path metric."""

        ind = np.argsort(self.msg_pm, axis=-1)
        self.msg_pm = np.take_along_axis(self.msg_pm, ind, axis=1)
        self._dec_pointer = np.take_along_axis(self._dec_pointer, ind, axis=1)

    def _cn_op_np(self, x, y):
        """Check node update (boxplus) for LLRs in Numpy.

        See [Stimming_LLR]_ and [Hashemi_SSCL]_ for detailed equations.
        """
        x_in = np.maximum(np.minimum(x, self._llr_max), -self._llr_max)
        y_in = np.maximum(np.minimum(y, self._llr_max), -self._llr_max)

        # avoid division for numerical stability
        llr_out = np.log(1 + np.exp(x_in + y_in))
        llr_out -= np.log(np.exp(x_in) + np.exp(y_in))

        return llr_out

    def _vn_op_np(self, x, y, u_hat):
        """Variable node update (boxplus) for LLRs in Numpy."""
        return np.multiply((1-2*u_hat), x) + y

    def _duplicate_paths_np(self):
        """Copy first ``list_size``/2 paths into lower part in Numpy.

        Decoder indices are encoded in ``self._dec_pointer``.
        """
        ind_low = self._dec_pointer[:, :self._list_size]
        ind_up = self._dec_pointer[:, self._list_size:]

        for i in range(ind_up.shape[0]):
            self.msg_uhat[i, ind_up[i,:], :, :] = self.msg_uhat[i,
                                                                ind_low[i,:],
                                                                :, :]
            self.msg_llr[i, ind_up[i,:],:,:] = self.msg_llr[i, ind_low[i,:],:,:]

        # pm must be sorted directly (not accessed via pointer)
        self.msg_pm[:, self._list_size:] = self.msg_pm[:, :self._list_size]

    def _polar_decode_scl_np(self, cw_ind):
        """Recursive decoding function in Numpy.

        We follow the terminology from [Hashemi_SSCL]_ and [Stimming_LLR]_
        and branch the messages into a `left` and `right` update paths until
        reaching a leaf node.

        Tree pruning as proposed in [Hashemi_SSCL]_ is used to minimize the
        tree depth while maintaining the same output.

        All messages are updated in-place in ``self.msg_llr``,
        ``self.msg_uhat`` and ``self.msg_pm``.
        """
        # current sub-code length and stage index (= tree depth)
        n = len(cw_ind)
        stage_ind = int(np.log2(n))

        # recursively branch through decoding tree
        if n>1:
            # prune tree if rate-0 subcode or rep-code is detected
            if self._use_fast_scl:
                if np.sum(self._frozen_ind[cw_ind])==n:
                    # rate0 code detected
                    self._update_rate0_code_np(cw_ind)
                    return
                if (self._frozen_ind[cw_ind[-1]]==0 and
                    np.sum(self._frozen_ind[cw_ind[:-1]])==n-1):
                    # rep code detected
                    self._update_rep_code_np(cw_ind)
                    return
            cw_ind_left = cw_ind[0:int(n/2)]
            cw_ind_right = cw_ind[int(n/2):]

            # ----- left branch -----
            # check-node update writes to the next lower stage
            llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
            llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]

            self.msg_llr[:, :, stage_ind-1, cw_ind_left] = self._cn_op_np(
                                                                    llr_left,
                                                                    llr_right)

            # call left branch decoder
            self._polar_decode_scl_np(cw_ind_left)

            # ----- right branch -----
            # variable-node update re-uses the left-branch decisions
            u_hat_left_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_left]
            llr_left = self.msg_llr[:, :, stage_ind, cw_ind_left]
            llr_right = self.msg_llr[:, :, stage_ind, cw_ind_right]

            self.msg_llr[:, :, stage_ind-1, cw_ind_right] = self._vn_op_np(
                                                                llr_left,
                                                                llr_right,
                                                                u_hat_left_up)

            # call right branch decoder
            self._polar_decode_scl_np(cw_ind_right)

            # combine u_hat
            u_hat_left_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_left]
            u_hat_right_up = self.msg_uhat[:, :, stage_ind-1, cw_ind_right]

            # u_hat_left_up XOR u_hat_right_up (int result)
            u_hat_left =  (u_hat_left_up != u_hat_right_up) + 0

            u_hat = np.concatenate([u_hat_left, u_hat_right_up], axis=-1)

            # provide u_hat for next higher stage
            self.msg_uhat[:, :, stage_ind,  cw_ind] = u_hat

        else: # if leaf is reached perform basic decoding op (=decision)

            self._update_single_bit_np(cw_ind)

            # update PM
            self._update_pm_np(cw_ind)

            # position is non-frozen
            if self._frozen_ind[cw_ind]==0:
                # sort list
                self._sort_decoders_np()
                # duplicate the best list_size decoders
                self._duplicate_paths_np()
        return

    def _decode_np_batch(self, llr_ch):
        """Decode batch of ``llr_ch`` with Numpy decoder.

        Allocates the message memory for all ``2*list_size`` decoders,
        initializes path metrics and decoder pointers, runs the recursive
        Numpy SCL decoder and returns ``(msg_uhat, msg_pm)`` in sorted
        (best decoder first) order.
        """
        bs = llr_ch.shape[0]
        num_dec = 2*self._list_size

        # message memory for all decoders
        self.msg_uhat = np.zeros([bs, num_dec, self._n_stages+1, self._n])
        self.msg_llr = np.zeros([bs, num_dec, self._n_stages+1, self._n])
        self.msg_pm = np.zeros([bs, num_dec])

        # within each decoder half, only the first path starts without
        # penalty; the remaining L-1 decoders start with a high penalty
        self.msg_pm[:, 1:self._list_size] = self._llr_max
        self.msg_pm[:, self._list_size+1:] = self._llr_max

        # pointers avoid in-memory sorting of the large message tensors
        self._dec_pointer = np.tile(np.arange(num_dec)[np.newaxis, :],
                                    [bs, 1])

        # broadcast channel LLRs into the last stage of every decoder
        self.msg_llr[:, :, self._n_stages, :] = np.expand_dims(llr_ch, axis=1)

        # run recursive SCL decoding over the full codeword
        self._polar_decode_scl_np(self._cw_ind)

        # sort by path metric (updates self._dec_pointer)
        self._sort_decoders_np()

        # resolve pointers so the decoders are stored in sorted order
        for ind in range(bs):
            self.msg_uhat[ind] = self.msg_uhat[ind, self._dec_pointer[ind]]
        return self.msg_uhat, self.msg_pm

    def _decode_np_hybrid(self, llr_ch, u_hat_sc, crc_valid):
        """Hybrid SCL decoding stage that decodes iff CRC from previous SC
        decoding attempt failed.

        This option avoids the usage of the high-complexity SCL decoder in
        cases where SC would be sufficient. For further details we refer to
        [Cammerer_Hybrid_SCL]_ (we use SC instead of the proposed BP stage).

        Remark: This decoder does not exactly implement SCL as the CRC
        can be false positive after the SC stage. However, in these cases
        SCL+CRC may also yield the wrong results.

        Remark 2: Due to the excessive control flow (if/else) and the
        varying batch-sizes, this function is only available as Numpy
        decoder (i.e., runs on the CPU).
        """
        bs = llr_ch.shape[0]
        crc_valid = np.squeeze(crc_valid, axis=-1)

        # run SCL decoding only for the samples with failed CRC
        ind_invalid = np.arange(bs)[np.invert(crc_valid)]
        llr_ch_hyb = np.take(llr_ch, ind_invalid, axis=0)
        msg_uhat_hyb, msg_pm_hyb = self._decode_np_batch(llr_ch_hyb)

        # default result: SC estimate in path 0, all other paths penalized
        msg_uhat = np.zeros([bs, 2*self._list_size, 1, self._n])
        msg_pm = np.ones([bs, 2*self._list_size]) * self._llr_max * self.k
        msg_pm[:, 0] = 0
        msg_uhat[:, 0, 0, self._info_pos] = u_hat_sc

        # overwrite with the SCL results where SC failed
        ind_hyb = 0
        for ind in range(bs):
            if not crc_valid[ind]:
                msg_uhat[ind, :, 0, :] = msg_uhat_hyb[ind_hyb, :, 0, :]
                msg_pm[ind, :] = msg_pm_hyb[ind_hyb, :]
                ind_hyb += 1

        return msg_uhat, msg_pm

    ########################
    # Sionna Block functions
    ########################

    def build(self, input_shape):
        """Build and check if shape of input is invalid.

        The last input dimension must equal the codeword length ``n``.
        """
        n_in = input_shape[-1]
        if n_in != self._n:
            raise ValueError("Invalid input shape.")

    def call(self, llr_ch):
        """Successive cancellation list (SCL) decoding function.

        This function performs successive cancellation list decoding
        and returns the estimated information bits.

        An outer CRC can be applied optionally by setting ``crc_degree``.

        Args:
            llr_ch (tf.float): Tensor of shape `[...,n]` containing the
                channel LLR values (as logits).

        Returns:
            `tf.float`: Tensor of shape `[...,k]` containing
            hard-decided estimations of all ``k`` information bits.

        Note:
        This function recursively unrolls the SCL decoding tree, thus,
        for larger values of ``n`` building the decoding graph can become
        time consuming. Please consider the ``cpu_only`` option instead.
        """

        # flatten all leading dimensions into a single batch dimension
        input_shape = llr_ch.shape
        new_shape = [-1, self._n]
        llr_ch = tf.reshape(llr_ch, new_shape)

        llr_ch = -1. * llr_ch # logits are converted into "true" llrs

        # if activated use Numpy decoder
        if self._use_hybrid_sc:
            # use SC decoder to decode first
            # (re-negate llr_ch, as the SC decoder consumes logits again)
            u_hat = self._decoder_sc(-llr_ch)
            _, crc_valid = self._crc_decoder(u_hat)
            # SCL decoding (on CPU) only for the samples with failed CRC
            msg_uhat, msg_pm = tf.py_function(func=self._decode_np_hybrid,
                                              inp=[llr_ch, u_hat, crc_valid],
                                              Tout=[self.rdtype, self.rdtype])
            # note: return shape is only 1 in 3. dim (to avoid copy overhead)
            msg_uhat = tf.reshape(msg_uhat, [-1, 2*self._list_size, 1, self._n])
            msg_pm = tf.reshape(msg_pm, [-1, 2*self._list_size])
        else:
            if self._cpu_only:
                msg_uhat, msg_pm = tf.py_function(func=self._decode_np_batch,
                                                inp=[llr_ch],
                                                Tout=[self.rdtype, self.rdtype])
                # restore shape information
                msg_uhat = tf.reshape(msg_uhat,
                            [-1, 2*self._list_size, self._n_stages+1, self._n])
                msg_pm = tf.reshape(msg_pm, [-1, 2*self._list_size])
            else:
                msg_uhat, msg_pm = self._decode_tf(llr_ch)

        # check CRC (and remove CRC parity bits)
        if self._use_crc:
            u_hat_list = tf.gather(msg_uhat[:, :, 0, :],
                                   self._info_pos,
                                   axis=-1)
            # undo input bit interleaving
            # remark: the output is not interleaved for compatibility with SC
            if self._iil:
                u_hat_list_crc = tf.gather(u_hat_list,
                                           self._ind_iil_inv,
                                           axis=-1)
            else: # no interleaving applied
                u_hat_list_crc = u_hat_list

            _, crc_valid = self._crc_decoder(u_hat_list_crc)
            # add penalty to pm if CRC fails
            pm_penalty = ((1. - tf.cast(crc_valid, self.rdtype))
                       * self._llr_max * self.k)
            msg_pm += tf.squeeze(pm_penalty, axis=2)

        # select most likely candidate (smallest path metric)
        cand_ind = tf.argmin(msg_pm, axis=-1)
        c_hat = tf.gather(msg_uhat[:, :, 0, :], cand_ind, axis=1, batch_dims=1)
        u_hat = tf.gather(c_hat, self._info_pos, axis=-1)

        # and reconstruct input shape
        output_shape = input_shape.as_list()
        output_shape[-1] = self.k
        output_shape[0] = -1 # first dim can be dynamic (None)
        u_hat_reshape = tf.reshape(u_hat, output_shape)

        if self._return_crc_status:
            # reconstruct CRC status of the selected candidate
            crc_status = tf.gather(crc_valid, cand_ind, axis=1, batch_dims=1)
            # reconstruct shape
            output_shape.pop() # remove last dimension
            crc_status = tf.reshape(crc_status, output_shape)

            # return info bits and CRC status
            return u_hat_reshape, crc_status
        else: # return only info bits
            return u_hat_reshape


class PolarBPDecoder(Block):
    # pylint: disable=line-too-long
    """Belief propagation (BP) decoder for Polar codes [Arikan_Polar]_ and
    Polar-like codes based on [Arikan_BP]_ and [Forney_Graphs]_.

    Remark: The PolarBPDecoder does currently not support XLA.

    Parameters
    ----------
    frozen_pos: ndarray
        Array of `int` defining the ``n-k`` indices of the frozen positions.

    n: int
        Defining the codeword length.

    num_iter: int
        Defining the number of decoder iterations (no early stopping used
        at the moment).

    hard_out: `bool`, (default `True`)
        If `True`,  the decoder provides hard-decided
        information bits instead of soft-values.

    precision : `None` (default) | 'single' | 'double'
        Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.

    Input
    -----
    llr_ch: [...,n], tf.float32
        Tensor containing the channel logits/llr values.

    Output
    ------
    : [...,k], tf.float32
        Tensor containing bit-wise soft-estimates
        (or hard-decided bit-values) of all ``k`` information bits.

    Note
    ----
    This decoder is fully differentiable and, thus, well-suited for
    gradient descent-based learning tasks such as `learned code design`
    [Ebada_Design]_.

    As commonly done, we assume frozen bits are set to `0`. Please note
    that - although its practical relevance is only little - setting frozen
    bits to `1` may result in `affine` codes instead of linear code as the
    `all-zero` codeword is not necessarily part of the code any more.
    """

    def __init__(self,
                 frozen_pos,
                 n,
                 num_iter=20,
                 hard_out=True,
                 precision=None,
                 **kwargs):

        super().__init__(precision=precision, **kwargs)

        # input validation (raise if r>1 or k, n are negative)
        if not isinstance(n, numbers.Number):
            raise TypeError("n must be a number.")
        n = int(n)  # n can be float (e.g. as result of n=k*r)
        if not np.issubdtype(frozen_pos.dtype, int):
            raise TypeError("frozen_pos contains non int.")
        if len(frozen_pos) > n:
            raise ValueError(
                "Num. of elements in frozen_pos cannot be greater than n.")
        if np.log2(n) != int(np.log2(n)):
            raise ValueError("n must be a power of 2.")
        if not isinstance(hard_out, bool):
            raise TypeError("hard_out must be boolean.")

        # store internal attributes
        self._n = n
        self._frozen_pos = frozen_pos
        self._k = n - len(frozen_pos)
        # info positions are the complement of the frozen positions
        self._info_pos = np.setdiff1d(np.arange(n), frozen_pos)
        if self._k != len(self._info_pos):
            raise ArithmeticError("Internal error: invalid info_pos generated.")

        if not isinstance(num_iter, int):
            raise TypeError("num_iter must be integer.")
        if num_iter <= 0:
            raise ValueError("num_iter must be a positive value.")
        self._num_iter = tf.constant(num_iter, dtype=tf.int32)

        self._llr_max = 19.3  # internal max LLR value (clipping)
        self._hard_out = hard_out

        # depth of decoding graph
        self._n_stages = int(np.log2(n))

    ###############################
    # Public methods and properties
    ###############################

    @property
    def n(self):
        """Codeword length ``n``"""
        return self._n

    @property
    def k(self):
        """Number of information bits ``k``"""
        return self._k

    @property
    def frozen_pos(self):
        """Frozen bit positions for Polar decoding"""
        return self._frozen_pos

    @property
    def info_pos(self):
        """Information bit positions for Polar encoding"""
        return self._info_pos

    @property
    def llr_max(self):
        """Maximum LLR value used for internal calculations (clipping)"""
        return self._llr_max

    @property
    def num_iter(self):
        """Number of decoding iterations"""
        return self._num_iter

    @property
    def hard_out(self):
        """Indicates if the decoder hard-decides the outputs"""
        return self._hard_out

    @num_iter.setter
    def num_iter(self, num_iter):
        """Set the number of decoding iterations (non-negative `int`)."""
        if not isinstance(num_iter, int):
            raise ValueError('num_iter must be int.')
        if num_iter<0:
            raise ValueError('num_iter cannot be negative.')
        # stored as tf.constant for use inside the decoding graph
        self._num_iter = tf.constant(num_iter, dtype=tf.int32)

    #################
    # Utility methods
    #################

    def _boxplus_tf(self, x, y):
        """Check-node update (boxplus) for LLR inputs.

        Evaluates log((1+e^(x+y)) / (e^x+e^y)) element-wise with clipped
        inputs; the division-free log-domain form is numerically stable.
        """
        x_clip = tf.clip_by_value(x, -self._llr_max, self._llr_max)
        y_clip = tf.clip_by_value(y, -self._llr_max, self._llr_max)

        num = tf.math.log(1 + tf.math.exp(x_clip + y_clip))
        den = tf.math.log(tf.math.exp(x_clip) + tf.math.exp(y_clip))
        return num - den

    def _decode_bp(self, llr_ch, num_iter):
        """Iterative BP decoding function with LLR-values.

        Args:
            llr_ch (tf.float32): Tensor of shape `[batch_size, n]` containing
                the channel logits/llr values where `batch_size` denotes the
                batch-size.

            num_iter (int): Defining the number of decoder iterations
                (no early stopping used at the moment).
        Returns:
            `tf.float32`: Tensor of shape `[batch_size, k]` containing
            bit-wise soft-estimates (or hard-decided bit-values) of all
            information bits.
        """

        bs = tf.shape(llr_ch)[0]

        # store intermediate messages of all stages/iterations in TensorArrays
        msg_l = tf.TensorArray(self.rdtype,
                               size=num_iter*(self._n_stages+1),
                               dynamic_size=False,
                               clear_after_read=False)

        msg_r = tf.TensorArray(self.rdtype,
                               size=num_iter*(self._n_stages+1),
                               dynamic_size=False,
                               clear_after_read=False)

        # init frozen positions with "infinity" (max. LLR value) as frozen
        # bits are known to be 0
        msg_r_in = np.zeros([1, self._n])
        msg_r_in[:, self._frozen_pos] = self._llr_max
        # copy for all batch-samples; tf.constant already yields rdtype, so
        # no additional cast is required
        msg_r_in = tf.tile(tf.constant(msg_r_in, self.rdtype), [bs, 1])

        # perform decoding iterations
        # note: iterate over the num_iter argument (not self._num_iter) to
        # stay consistent with the TensorArray sizes and final read index
        for ind_it in tf.range(num_iter):
            # update left-to-right messages
            for ind_s in range(self._n_stages):
                # calc indices of upper/lower butterfly inputs at this stage
                ind_range = np.arange(int(self._n/2))
                ind_1 = ind_range * 2 - np.mod(ind_range, 2**ind_s)
                ind_2 = ind_1 + 2**ind_s
                # inverse permutation restores natural order after concat
                ind_inv = np.argsort(np.concatenate([ind_1, ind_2], axis=0))

                # load incoming l messages
                if ind_s==self._n_stages-1:
                    # last stage is directly connected to the channel LLRs
                    l1_in = tf.gather(llr_ch, ind_1, axis=1)
                    l2_in = tf.gather(llr_ch, ind_2, axis=1)
                elif ind_it==0:
                    # no l messages available yet in the first iteration
                    l1_in = tf.zeros([bs, int(self._n/2)], self.rdtype)
                    l2_in = tf.zeros([bs, int(self._n/2)], self.rdtype)
                else:
                    # re-use l messages from the previous iteration
                    l_in = msg_l.read((ind_s+1) + (ind_it-1)*(self._n_stages+1))
                    l1_in = tf.gather(l_in, ind_1, axis=1)
                    l2_in = tf.gather(l_in, ind_2, axis=1)

                # load incoming r messages
                if ind_s==0:
                    # first stage is connected to the a priori (frozen) info
                    r1_in = tf.gather(msg_r_in, ind_1, axis=1)
                    r2_in = tf.gather(msg_r_in, ind_2, axis=1)
                else:
                    r_in = msg_r.read(ind_s + ind_it*(self._n_stages+1))
                    r1_in = tf.gather(r_in, ind_1, axis=1)
                    r2_in = tf.gather(r_in, ind_2, axis=1)

                # node update functions
                r1_out = self._boxplus_tf(r1_in, l2_in + r2_in)
                r2_out = self._boxplus_tf(r1_in, l1_in) + r2_in

                # and re-concatenate output
                r_out = tf.concat([r1_out, r2_out], 1)
                r_out = tf.gather(r_out, ind_inv, axis=1)
                msg_r = msg_r.write((ind_s+1)
                                     + ind_it*(self._n_stages+1), r_out)

            # update right-to-left messages
            for ind_s in range(self._n_stages-1, -1, -1):
                ind_range = np.arange(int(self._n/2))
                ind_1 = ind_range * 2 - np.mod(ind_range, 2**ind_s)
                ind_2 = ind_1 + 2**ind_s
                ind_inv = np.argsort(np.concatenate([ind_1, ind_2], axis=0))

                # load messages
                if ind_s==self._n_stages-1:
                    l1_in = tf.gather(llr_ch, ind_1, axis=1)
                    l2_in = tf.gather(llr_ch, ind_2, axis=1)
                else:
                    l_in = msg_l.read((ind_s+1)+ind_it*(self._n_stages+1))
                    l1_in = tf.gather(l_in, ind_1, axis=1)
                    l2_in = tf.gather(l_in, ind_2, axis=1)

                if ind_s==0:
                    r1_in = tf.gather(msg_r_in, ind_1, axis=1)
                    r2_in = tf.gather(msg_r_in, ind_2, axis=1)
                else:
                    r_in = msg_r.read(ind_s + ind_it*(self._n_stages+1))
                    r1_in = tf.gather(r_in, ind_1, axis=1)
                    r2_in = tf.gather(r_in, ind_2, axis=1)

                # node update functions
                l1_out = self._boxplus_tf(l1_in, l2_in + r2_in)
                l2_out = self._boxplus_tf(r1_in, l1_in) + l2_in

                l_out = tf.concat([l1_out, l2_out], 1)
                l_out = tf.gather(l_out, ind_inv, axis=1)
                msg_l = msg_l.write(ind_s + ind_it*(self._n_stages+1), l_out)

        # recover u_hat from the stage-0 l messages of the last iteration
        u_hat = tf.gather(msg_l.read((num_iter-1)*(self._n_stages+1)),
                          self._info_pos, axis=1)
        # if active, hard-decide output bits
        if self._hard_out:
            # positive LLR indicates bit "0" (LLR convention)
            u_hat = tf.where(u_hat>0,
                             tf.constant(0., dtype=self.rdtype),
                             tf.constant(1., dtype=self.rdtype))
        else: # re-transform soft output to logits (instead of llrs)
            u_hat = -1. * u_hat
        return u_hat

    ########################
    # Sionna Block functions
    ########################

    def build(self, input_shape):
        """Validate that the input's last dimension equals the code length."""
        # only the last axis is constrained; leading dims are arbitrary
        if input_shape[-1] != self._n:
            raise ValueError("Invalid input shape")

    def call(self, llr_ch):
        """Apply iterative BP decoding to the channel observations.

        Runs ``num_iter`` belief propagation iterations and returns the
        estimated information bits.

        Args:
            llr_ch (tf.float): Tensor of shape `[...,n]` containing the
                channel logits/llr values.

        Returns:
            `tf.float`: Tensor of shape `[...,k]` containing
                bit-wise soft-estimates (or hard-decided bit-values) of all
                ``k`` information bits.

        Note:
            This function recursively unrolls the BP decoding graph, thus,
            for larger values of ``n`` or more iterations, building the
            decoding graph can become time and memory consuming.
        """

        # flatten all leading dimensions into a single batch dimension
        in_shape = llr_ch.shape
        llr_flat = tf.reshape(llr_ch, [-1, self._n])

        # the decoder internally operates on "true" llrs instead of logits
        llr_flat = -1. * llr_flat

        # run BP decoding on the flattened batch
        u_hat = self._decode_bp(llr_flat, self._num_iter)

        # restore the original leading dimensions; last axis becomes k
        out_shape = in_shape.as_list()
        out_shape[-1] = self.k
        out_shape[0] = -1 # first dim can be dynamic (None)
        return tf.reshape(u_hat, out_shape)


class Polar5GDecoder(Block):
    # pylint: disable=line-too-long
    """Wrapper for 5G compliant decoding including rate-recovery and CRC removal.

    Parameters
    ----------
    enc_polar: Polar5GEncoder
        Instance of the :class:`~sionna.phy.fec.polar.encoding.Polar5GEncoder`
        used for encoding including rate-matching.

    dec_type: "SC" (default) | "SCL" | "hybSCL" | "BP"
        Defining the decoder to be used.
        Must be one of the following `{"SC", "SCL", "hybSCL", "BP"}`.

    list_size: int, (default 8)
        Defining the list size `iff` list-decoding is used.
        Only required for ``dec_types`` `{"SCL", "hybSCL"}`.

    num_iter: int, (default 20)
        Defining the number of BP iterations. Only required for ``dec_type``
        `"BP"`.

    return_crc_status: `bool`, (default `False`)
        If `True`,  the decoder additionally returns the CRC status indicating
        if a codeword was (most likely) correctly recovered.

    precision : `None` (default) | 'single' | 'double'
        Precision used for internal calculations and outputs.
        If set to `None`, :py:attr:`~sionna.phy.config.precision` is used.

    Input
    -----
    llr_ch: [...,n], tf.float
        Tensor containing the channel logits/llr values.

    Output
    ------
    b_hat : [...,k], tf.float
        Binary tensor containing hard-decided estimations of all `k`
        information bits.

    crc_status : [...], tf.bool
        CRC status indicating if a codeword was (most likely) correctly
        recovered. This is only returned if ``return_crc_status`` is True.
        Note that false positives are possible.

    Note
    ----
    This block supports the uplink and downlink Polar rate-matching scheme
    without `codeword segmentation`.

    Although the decoding `list size` is not provided by 3GPP
    [3GPPTS38212]_, the consortium has agreed on a `list size` of 8 for the
    5G decoding reference curves [Bioglio_Design]_.

    All list-decoders apply `CRC-aided` decoding, however, the non-list
    decoders (`"SC"` and `"BP"`) cannot materialize the CRC leading to an
    effective rate-loss.

    """

    def __init__(self,
                 enc_polar,
                 dec_type="SC",
                 list_size=8,
                 num_iter=20,
                 return_crc_status=False,
                 precision=None,
                 **kwargs):

        super().__init__(precision=precision, **kwargs)

        if not isinstance(enc_polar, Polar5GEncoder):
            raise TypeError("enc_polar must be Polar5GEncoder.")
        if not isinstance(dec_type, str):
            raise TypeError("dec_type must be str.")

        # list_size and num_iter are not checked here (done during decoder init)

        # Store internal attributes
        self._n_target = enc_polar.n_target
        self._k_target = enc_polar.k_target
        self._n_polar = enc_polar.n_polar
        self._k_polar = enc_polar.k_polar
        self._k_crc = enc_polar.enc_crc.crc_length
        # channel-bit and input-bit interleaving depend on the channel type
        self._bil = enc_polar._channel_type == "uplink"
        self._iil = enc_polar._channel_type == "downlink"
        self._llr_max = 100 # Internal max LLR value (for punctured positions)
        self._enc_polar = enc_polar
        self._dec_type = dec_type

        # Initialize the de-interleaver patterns
        self._init_interleavers()

        # Initialize decoder
        if dec_type=="SC":
            print("Warning: 5G Polar codes use an integrated CRC that " \
                  "cannot be materialized with SC decoding and, thus, " \
                  "causes a degraded performance. Please consider SCL " \
                  "decoding instead.")
            self._polar_dec = PolarSCDecoder(self._enc_polar.frozen_pos,
                                             self._n_polar)
        elif dec_type=="SCL":
            # SCL materializes the CRC (CRC-aided decoding); for downlink,
            # the input-bit interleaver is inverted internally (ind_iil_inv)
            self._polar_dec = PolarSCLDecoder(self._enc_polar.frozen_pos,
                                self._n_polar,
                                crc_degree=self._enc_polar.enc_crc.crc_degree,
                                list_size=list_size,
                                ind_iil_inv = self.ind_iil_inv)
        elif dec_type=="hybSCL":
            # hybrid mode runs SC first and falls back to SCL on CRC failure
            self._polar_dec = PolarSCLDecoder(self._enc_polar.frozen_pos,
                                self._n_polar,
                                crc_degree=self._enc_polar.enc_crc.crc_degree,
                                list_size=list_size,
                                use_hybrid_sc=True,
                                ind_iil_inv = self.ind_iil_inv)
        elif dec_type=="BP":
            print("Warning: 5G Polar codes use an integrated CRC that " \
                  "cannot be materialized with BP decoding and, thus, " \
                  "causes a degraded performance. Please consider SCL " \
                  "decoding instead.")
            if not isinstance(num_iter, int):
                raise TypeError("num_iter must be int.")
            if num_iter <= 0:
                raise ValueError("num_iter must be positive.")
            self._num_iter = num_iter
            self._polar_dec = PolarBPDecoder(self._enc_polar.frozen_pos,
                                             self._n_polar,
                                             num_iter=num_iter,
                                             hard_out=True)
        else:
            raise ValueError("Unknown value for dec_type.")

        if not isinstance(return_crc_status, bool):
            raise TypeError("return_crc_status must be bool.")

        self._return_crc_status = return_crc_status
        if self._return_crc_status: # init crc decoder
            if dec_type in ("SCL", "hybSCL"):
                # re-use CRC decoder from list decoder
                self._dec_crc = self._polar_dec._crc_decoder
            else: # init new CRC decoder for BP and SC
                # use the public enc_crc accessor (consistent with above)
                self._dec_crc = CRCDecoder(self._enc_polar.enc_crc)

    ###############################
    # Public methods and properties
    ###############################

    @property
    def k_target(self):
        """Number of information bits including rate-matching"""
        return self._k_target

    @property
    def n_target(self):
        """Codeword length including rate-matching"""
        return self._n_target

    @property
    def k_polar(self):
        """Number of information bits of mother Polar code"""
        return self._k_polar

    @property
    def n_polar(self):
        """Codeword length of mother Polar code"""
        return self._n_polar

    @property
    def frozen_pos(self):
        """Frozen positions for Polar decoding"""
        # delegate to the encoder; self._frozen_pos is never assigned in
        # __init__ and accessing it raised an AttributeError before
        return self._enc_polar.frozen_pos

    @property
    def info_pos(self):
        """Information bit positions for Polar encoding"""
        # delegate to the encoder; self._info_pos is never assigned in
        # __init__ and accessing it raised an AttributeError before
        return self._enc_polar.info_pos

    @property
    def llr_max(self):
        """Maximum LLR value for internal calculations"""
        return self._llr_max

    @property
    def dec_type(self):
        """Decoder type used for decoding as str"""
        return self._dec_type

    @property
    def polar_dec(self):
        """Decoder instance used for decoding"""
        return self._polar_dec

    #################
    # Utility methods
    #################

    def _init_interleavers(self):
        """Initialize inverse interleaver patterns for rate-recovery."""

        # Channel interleaver
        ind_ch_int = self._enc_polar.channel_interleaver(
                                                np.arange(self._n_target))
        self.ind_ch_int_inv = np.argsort(ind_ch_int) # Find inverse perm

        # Sub-block interleaver
        ind_sub_int = self._enc_polar.subblock_interleaving(
                                                np.arange(self._n_polar))
        self.ind_sub_int_inv = np.argsort(ind_sub_int) # Find inverse perm

        # input bit interleaver (only applied for downlink channels)
        if self._iil:
            self.ind_iil_inv = np.argsort(self._enc_polar.input_interleaver(
                                                np.arange(self._k_polar)))
        else:
            self.ind_iil_inv = None

    ########################
    # Sionna Block functions
    ########################

    def build(self, input_shape):
        """Build and check if shape of input is invalid."""
        if input_shape[-1]!=self._n_target:
            raise ValueError("Invalid input shape.")

    def call(self, llr_ch):
        """Polar decoding and rate-recovery for 5G Polar codes.

        Args:
            llr_ch (tf.float): Tensor of shape `[...,n]` containing the
                channel logits/llr values.

        Returns:
            `tf.float`: Tensor of shape `[...,k]` containing
                hard-decided estimates of all ``k`` information bits.
        """

        input_shape = llr_ch.shape
        new_shape = [-1, self._n_target]
        llr_ch = tf.reshape(llr_ch, new_shape)

        # Note: logits are not inverted here; this is done in the decoder itself

        # 1.) Undo channel interleaving (uplink only)
        if self._bil:
            llr_deint = tf.gather(llr_ch, self.ind_ch_int_inv, axis=1)
        else:
            llr_deint = llr_ch

        # 2.) Remove puncturing, shortening, repetition (see Sec. 5.4.1.2)
        # a) Puncturing: set LLRs to 0
        # b) Shortening: set LLRs to infinity
        # c) Repetition: combine LLRs
        if self._n_target >= self._n_polar:
            # Repetition coding
            # Add the last n_rep positions to the first llr positions
            n_rep = self._n_target - self._n_polar
            llr_1 = llr_deint[:,:n_rep]
            llr_2 = llr_deint[:,n_rep:self._n_polar]
            llr_3 = llr_deint[:,self._n_polar:]
            llr_dematched = tf.concat([llr_1+llr_3, llr_2], 1)
        else:
            if self._k_polar/self._n_target <= 7/16:
                # Puncturing
                # Append n_polar - n_target "zero" llrs to first positions
                llr_zero = tf.zeros([tf.shape(llr_deint)[0],
                                     self._n_polar-self._n_target], self.rdtype)
                llr_dematched = tf.concat([llr_zero, llr_deint], 1)
            else:
                # Shortening
                # Append n_polar - n_target "-infinity" llrs to last positions
                # Remark: we still operate with logits here, thus the neg. sign
                llr_infty = -self._llr_max * tf.ones([tf.shape(llr_deint)[0],
                                                self._n_polar-self._n_target],
                                                self.rdtype)
                llr_dematched = tf.concat([llr_deint, llr_infty], 1)

        # 3.) Remove subblock interleaving
        llr_dec = tf.gather(llr_dematched, self.ind_sub_int_inv, axis=1)

        # 4.) Run main decoder
        u_hat_crc = self._polar_dec(llr_dec)

        # 5.) Shortening should be implicitly recovered by decoder

        # 6.) Remove input bit interleaving for downlink channels only
        if self._iil:
            u_hat_crc = tf.gather(u_hat_crc, self.ind_iil_inv, axis=1)

        # 7.) Evaluate or remove CRC (and PC)
        if self._return_crc_status:
            # for compatibility with SC/BP, a dedicated CRC decoder is
            # used here (instead of accessing the interal SCL)
            u_hat, crc_status = self._dec_crc(u_hat_crc)
        else: # just remove CRC bits
            u_hat = u_hat_crc[:,:-self._k_crc]

        # And reconstruct input shape
        output_shape = input_shape.as_list()
        output_shape[-1] = self._k_target
        output_shape[0] = -1 # First dim can be dynamic (None)
        u_hat_reshape = tf.reshape(u_hat, output_shape)
        # and cast to internal rdtype (as subblocks may have different configs)
        u_hat_reshape = tf.cast(u_hat_reshape, dtype=self.rdtype)

        if self._return_crc_status:
            # reconstruct CRC shape
            output_shape.pop() # remove last dimension
            crc_status = tf.reshape(crc_status, output_shape)
            return u_hat_reshape, crc_status

        else:
            return u_hat_reshape
